import argparse
import functools
import os
import sys

import torch
from spinup.algos.pytorch.ppo.ppo import ppo

from env.mujoco_env.reacher_env import ReacherGymEnv
from env.mujoco_env.reacher_env import RMReacherGymEnv
from env.mujoco_env.reacher_env import ReacherGymEnvEval


def train(model):
    """Train PPO agents on the Reacher MuJoCo environment.

    Args:
        model: Which training setup to run.
            ``'lof'`` trains one policy on ``ReacherGymEnv`` for 1001 epochs
            (checkpoint every 50) and logs to ``$LOF_PKG_PATH/experiments/red``.
            ``'rm'`` trains one policy per task ('composite', 'sequential',
            'IF', 'OR') on ``RMReacherGymEnv`` for 1000 epochs each
            (checkpoint every 10), logging to
            ``$LOF_PKG_PATH/experiments/rm/<task_name>``.

    Raises:
        ValueError: If ``model`` is neither ``'lof'`` nor ``'rm'``.
        KeyError: If the ``LOF_PKG_PATH`` environment variable is not set.
    """
    # Fail loudly instead of silently training nothing on a typo.
    if model not in ('lof', 'rm'):
        raise ValueError(f"model must be 'lof' or 'rm', got {model!r}")

    # Actor-critic architecture and PPO hyperparameters shared by both setups.
    ac_kwargs = {
        'hidden_sizes': (128, 128),
        'activation': torch.nn.Tanh,
    }
    ppo_kwargs = {
        'ac_kwargs': ac_kwargs,
        'clip_ratio': 0.1,
        'pi_lr': 1e-4,
        'vf_lr': 1e-3,
        'lam': 0.99,
    }
    experiments_dir = os.path.join(os.environ['LOF_PKG_PATH'], 'experiments')

    if model == 'lof':
        # env = ReacherGymEnv({'headless': True, 'horizon': 100})
        logger_kwargs = {
            "output_dir": os.path.join(experiments_dir, 'red'),
            # NOTE(review): exp_name "rm" in the 'lof' branch looks copied from
            # the 'rm' branch -- confirm whether this is intended.
            "exp_name": "rm",
        }
        # The environment class itself serves as the zero-argument env factory.
        ppo(ReacherGymEnv,
            logger_kwargs=logger_kwargs,
            epochs=1001,
            save_freq=50,
            **ppo_kwargs)
        return

    num_epochs = 1000
    # (task_name, nF) pairs; nF is forwarded to RMReacherGymEnv, which
    # defines its meaning.
    tasks = [
        ('composite', 7),
        ('sequential', 5),
        ('IF', 5),
        ('OR', 3),
    ]
    for task_name, nF in tasks:
        # Disabled checkpoint reload (model990.pt). The ppo in use presumably
        # accepts reload_ac / start_epoch -- confirm before re-enabling:
        # load_path = os.path.join(experiments_dir, 'rm', task_name, 'pyt_save', 'model990.pt')
        # reload_ac = torch.load(load_path)
        logger_kwargs = {
            "output_dir": os.path.join(experiments_dir, 'rm', task_name),
            "exp_name": "rm",
        }
        # partial binds nF/task_name now; a lambda would read the loop
        # variables whenever ppo happens to call it.
        env_fn = functools.partial(
            RMReacherGymEnv, nF=nF, task_name=task_name, training=True)
        ppo(env_fn,
            logger_kwargs=logger_kwargs,
            epochs=num_epochs,
            save_freq=10,
            # reload_ac=reload_ac,
            # start_epoch=990,
            **ppo_kwargs)

if __name__ == "__main__":
    # Command-line entry point: `python <script> [lof|rm]`; defaults to 'lof'.
    # argparse rejects unknown models and extra arguments and provides -h/--help.
    parser = argparse.ArgumentParser(
        description="Train a PPO agent on the Reacher environment.")
    parser.add_argument(
        'model',
        nargs='?',
        default='lof',
        choices=['lof', 'rm'],
        help="training setup to run (default: lof)",
    )
    args = parser.parse_args()
    train(args.model)